import matplotlib as mpl
mpl.use('Agg')

import os
import torch
import torch.nn as nn

from utils.model_normalization import Cifar10Wrapper
import utils.datasets as dl
import utils.models.model_factory_32 as factory
import utils.run_file_helpers as rh
from distutils.util import strtobool
import ssl_utils as ssl
import utils.train_types as tt


import argparse

# Command-line hyperparameters for this CIFAR-10 run. Shared options (gpu, lr, bs,
# epochs, augm, train_type, continue_trained, ...) are registered by the
# utils.run_file_helpers calls at the end of this block.
parser = argparse.ArgumentParser(description='Define hyperparameters.', prefix_chars='-')
parser.add_argument('--net', type=str, default='ResNet18', help='Resnet18, 34 or 50, WideResNet28')
parser.add_argument('--model_params', nargs='+', default=[])
parser.add_argument('--dataset', type=str, default='cifar10', help='cifar10 or cifar10_ti_500k')
# int default (was the string '0', which only worked because argparse converts string defaults)
parser.add_argument('--cifar_subset', type=int, default=0, help='Use subset of X cifar images')
parser.add_argument('--od_dataset', type=str, default='tinyImages',
                    help=('tinyImages or cifar100'))
parser.add_argument('--exclude_cifar', dest='exclude_cifar', type=lambda x: bool(strtobool(x)),
                    default=True, help='whether to exclude cifar10 from tiny images')
parser.add_argument('--CEDA_label_smoothing', default=0, type=float,
                    help='Label smoothing epsilon for the CEDATargeted variants (0 disables it)')
parser.add_argument('--semi_ratio', type=int,
                    default=0, help='Fixed ratio or variable (0)')
parser.add_argument('--samples', type=int,
                    default=25000, help='Max additional samples per class')
parser.add_argument('--unlabeled_samples', type=int,
                    default=1_000_000, help='Max additional samples per class')
parser.add_argument('--teacher', type=str,
                    default=None, help='Teacher density_model')
parser.add_argument('--threshold', type=str,
                    default='0.998', help='TPR threshold')
parser.add_argument('--od_threshold', type=str,
                    default='same',
                    help="OD exclusion threshold: 'same' (reuse --threshold), 'none', or an explicit value")
parser.add_argument('--calibrate_temperature', dest='calibrate_temperature', type=lambda x: bool(strtobool(x)),
                    default=True, help='whether to use temperature calibration')

rh.parser_add_commons(parser)
rh.parser_add_adversarial_commons(parser)
rh.parser_add_adversarial_norms(parser, 'cifar10')

hps = parser.parse_args()
# Choose the compute device from --gpu: no ids -> CPU, one id -> that GPU,
# several ids -> the lowest one is primary and all ids are kept in device_ids
# (used for nn.DataParallel further down).
device_ids = None
num_gpus = len(hps.gpu)
if num_gpus > 1:
    device_ids = [int(gpu_id) for gpu_id in hps.gpu]
    device = torch.device(f'cuda:{min(device_ids)}')
elif num_gpus == 1:
    device = torch.device(f'cuda:{hps.gpu[0]}')
else:
    device = torch.device('cpu')
    print('Warning! Computing on CPU')

# parameters
# https://arxiv.org/pdf/1906.09453.pdf
t_obj = 'kl'  # passed as train_obj to the CEDA trainers
lam = 1.0  # passed as od_weight to the CEDA trainers

# Values taken directly from the parsed command line.
lr, bs, epochs = hps.lr, hps.bs, hps.epochs
network_name = hps.net.lower()
augm = hps.augm.lower()
exclude_cifar, nesterov = hps.exclude_cifar, hps.nesterov
od_dataset, ceda_label_smoothing = hps.od_dataset, hps.CEDA_label_smoothing
warmup_epochs, test_epochs = hps.warmup_epochs, hps.test_epochs

num_classes = 10

# Build the network; model and log directories are keyed by the model name.
img_size = 32
model_root_dir = 'Cifar10Models'
logs_root_dir = 'Cifar10Logs'

model, model_name, model_config = factory.build_model(network_name, num_classes)
model_dir, log_dir = (os.path.join(root_dir, model_name) for root_dir in (model_root_dir, logs_root_dir))

# load dataset
# The out-of-distribution batch size is a multiple of the in-distribution one.
od_bs = int(hps.od_bs_factor * bs)

# --threshold is forwarded as class_tpr_min. --od_threshold picks the OD exclusion
# threshold: 'same' reuses class_tpr_min, 'none' yields None, and any other value
# is passed through unchanged.
class_tpr_min = hps.threshold
od_exclusion_threshold = {'same': class_tpr_min, 'none': None}.get(hps.od_threshold, hps.od_threshold)
cutout_window = 16

dataset_classifications_path = ssl.get_dataset_classification_dir('cifar10')
epoch_subdivs = 1  # multiplies epochs into total_epochs

msda_config = rh.create_msda_config(hps)

def _build_id_train_loader(id_config):
    """Return the in-distribution training loader selected by --dataset.

    Shared by the CEDA and plain branches below; ``id_config`` is forwarded to
    the loader factory as its config dict.

    Raises:
        ValueError: if --dataset is neither 'cifar10' nor 'cifar10_ti_500k'.
    """
    if hps.dataset == 'cifar10':
        if hps.cifar_subset <= 0:
            return dl.get_CIFAR10(train=True, batch_size=bs, augm_type=augm,
                                  size=img_size, config_dict=id_config)
        # NOTE(review): passes the raw (not lower-cased) hps.augm and a float
        # hps.cifar_subset / 10, unlike the other loaders in this file - confirm intent.
        return ssl.get_CIFAR10_subset('train', hps.cifar_subset / 10, batch_size=bs, augm_type=hps.augm,
                                      shuffle=True, size=img_size, id_config=id_config)
    if hps.dataset == 'cifar10_ti_500k':
        return dl.get_CIFAR10_ti_500k(train=True, batch_size=bs, augm_type=augm,
                                      config_dict=id_config)
    raise ValueError(f'Dataset {hps.dataset} not supported')


if hps.train_type in ['plainKL', 'CEDATargetedKL', 'CEDAKL', 'CEDATargeted', 'CEDATargetedKLEntropy']:
    # Training data is assembled by the ssl partition helpers from the teacher
    # model (--teacher) and the TPR / OD exclusion thresholds computed above.
    samples_per_class = hps.samples
    calibrate_temperature = hps.calibrate_temperature
    verbose_exclude = False

    teacher_model = hps.teacher
    selection_model = None

    # Passed to the loader helpers as config dicts and to trainer.train via loader_config.
    ssl_config = {}
    id_config = {}
    od_config = {}

    if hps.train_type == 'plainKL':
        loader_config = {'SSL config': ssl_config, 'ID config': id_config}
    else:
        loader_config = {'SSL config': ssl_config, 'ID config': id_config, 'OD config': od_config}

    if hps.train_type in ['plainKL', 'CEDATargetedKL', 'CEDAKL']:
        soft_labels = True
    elif hps.train_type == 'CEDATargeted':
        soft_labels = False
    else:
        raise NotImplementedError(f'Train type {hps.train_type} is not implemented')

    if od_dataset == 'tinyImages':
        if hps.cifar_subset > 0:
            # A subset variant (ssl.get_tiny_cifar_subset_partition) is not wired up.
            raise NotImplementedError('use tinyImages_subset ')
        train_loader, od_loader = ssl.get_tiny_cifar_partition(dataset_classifications_path, teacher_model, 'cifar10',
                                                               samples_per_class, False, semi_ratio=hps.semi_ratio,
                                                               class_tpr_min=class_tpr_min,
                                                               od_exclusion_threshold=od_exclusion_threshold,
                                                               calibrate_temperature=calibrate_temperature,
                                                               verbose_exclude=verbose_exclude,
                                                               soft_labels=soft_labels, batch_size=bs,
                                                               augm_type=augm, aa_magnitude=1.0, size=img_size,
                                                               exclude_cifar=exclude_cifar,
                                                               exclude_cifar10_1=exclude_cifar,
                                                               id_config_dict=id_config, od_config_dict=od_config,
                                                               ssl_config=ssl_config)
    else:
        train_loader, od_loader = ssl.get_cifar_subset_plus_od_partition(teacher_model, 'cifar10', od_dataset,
                                                                         hps.cifar_subset // 10, samples_per_class,
                                                                         hps.unlabeled_samples,
                                                                         semi_ratio=hps.semi_ratio,
                                                                         class_tpr_min=class_tpr_min,
                                                                         od_exclusion_threshold=od_exclusion_threshold,
                                                                         calibrate_temperature=calibrate_temperature,
                                                                         verbose_exclude=verbose_exclude,
                                                                         soft_labels=soft_labels,
                                                                         batch_size=bs, augm_type=augm,
                                                                         size=img_size,
                                                                         aa_magnitude=1.0,
                                                                         id_config_dict=id_config,
                                                                         od_config_dict=od_config,
                                                                         ssl_config=ssl_config)
elif hps.train_type in ['CEDA', 'CEDAExtra']:
    # CEDA: in-distribution loader from --dataset plus a separate OD loader.
    id_config = {}
    od_config = {}
    loader_config = {'ID config': id_config, 'OD config': od_config}

    train_loader = _build_id_train_loader(id_config)

    if od_dataset == 'tinyImages':
        od_loader = dl.get_80MTinyImages(batch_size=od_bs, augm_type=augm, num_workers=1, size=img_size,
                                         exclude_cifar=exclude_cifar, exclude_cifar10_1=exclude_cifar,
                                         config_dict=od_config)
    else:
        # NOTE(review): this call uses config_file= where other loaders use config_dict=;
        # confirm against ssl.get_CIFAR10_subset_plus_OD's signature.
        od_loader = ssl.get_CIFAR10_subset_plus_OD('unlabeled', hps.cifar_subset / 10, od_dataset, hps.unlabeled_samples,
                                                   batch_size=od_bs, augm_type=augm, num_workers=8,
                                                   size=img_size, config_file=od_config)
else:
    # Any other train type only needs the in-distribution loader.
    id_config = {}
    loader_config = {'ID config': id_config}
    train_loader = _build_id_train_loader(id_config)

# Main test loader: the 'val' split when a CIFAR subset is used, otherwise
# CIFAR-10.1. The CIFAR-10 test split (train=False) is always an extra test loader.
if hps.cifar_subset > 0:
    test_loader = ssl.get_CIFAR10_subset('val', hps.cifar_subset / 10, batch_size=bs, augm_type='none',
                                         shuffle=True, size=img_size)
else:
    test_loader = dl.get_CIFAR10_1(batch_size=bs)
extra_test_loader = dl.get_CIFAR10(train=False, batch_size=bs, augm_type='none')

scheduler_config, optimizer_config = rh.create_optim_scheduler_swa_configs(hps)
total_epochs = epochs * epoch_subdivs

# load old density_model
# Optionally resume a previous run; --continue_trained supplies
# (run folder, checkpoint tag, epoch at which to resume training).
if hps.continue_trained is not None:
    load_folder = hps.continue_trained[0]
    load_epoch = hps.continue_trained[1]
    start_epoch = int(hps.continue_trained[2])
    # Named snapshots sit directly in the run folder; other tags under checkpoints/.
    if load_epoch in ['final', 'best', 'final_swa', 'best_swa']:
        checkpoint_prefix = f'{model_dir}/{load_folder}/{load_epoch}'
    else:
        checkpoint_prefix = f'{model_dir}/{load_folder}/checkpoints/{load_epoch}'
    state_dict_file = f'{checkpoint_prefix}.pth'
    optimizer_dict_file = f'{checkpoint_prefix}_optim.pth'

    state_dict = torch.load(state_dict_file, map_location=device)

    optim_state_dict = None
    if hps.continue_optim:
        try:
            optim_state_dict = torch.load(optimizer_dict_file, map_location=device)
        except Exception as err:
            # Best effort: fall back to a fresh optimizer, but report why, and do
            # not swallow KeyboardInterrupt/SystemExit like a bare except would.
            print(f'Warning: Could not load Optim State - Restarting optim ({err})')

    model.load_state_dict(state_dict)

    print(f'Continuing {load_folder} from epoch {load_epoch} - Starting training at epoch {start_epoch}')
else:
    start_epoch = 0
    optim_state_dict = None

# Wrap with Cifar10Wrapper (from utils.model_normalization) and move to the device.
model = Cifar10Wrapper(model).to(device)

# With several GPU ids, replicate the model across all of them.
if len(hps.gpu) > 1:
    model = nn.DataParallel(model, device_ids=device_ids)


#Train Type
# Keyword arguments shared by every trainer.
trainer_kwargs = dict(lr_scheduler_config=scheduler_config, msda_config=msda_config,
                      test_epochs=test_epochs, saved_model_dir=model_dir, saved_log_dir=log_dir)
# Additional arguments shared by all CEDA trainers (objective t_obj, OD weight lam).
ceda_kwargs = dict(train_obj=t_obj, od_weight=lam, **trainer_kwargs)
# Variant spec for the targeted CEDA trainers; a smoothing value of 0 is passed as None.
CEDA_VARIANT = {'Type': 'CEDATargeted',
                'LabelSmoothingEps': None if ceda_label_smoothing == 0 else ceda_label_smoothing}

if hps.train_type == 'plain':
    trainer = tt.PlainTraining(model, optimizer_config, total_epochs, device, num_classes,
                               **trainer_kwargs)
elif hps.train_type in ('CEDA', 'CEDAExtra'):
    # 'CEDAExtra' gets the same loaders as 'CEDA' above, so it uses the same trainer
    # (the old condition tested 'CEDA' twice and sent 'CEDAExtra' to ValueError).
    trainer = tt.CEDATraining(model, optimizer_config, total_epochs, device, num_classes,
                              **ceda_kwargs)
elif hps.train_type == 'CEDATargeted':
    trainer = tt.CEDATraining(model, optimizer_config, total_epochs, device, num_classes,
                              CEDA_variant=CEDA_VARIANT, **ceda_kwargs)
elif hps.train_type == 'CEDATargetedKL':
    trainer = tt.CEDATraining(model, optimizer_config, total_epochs, device, num_classes,
                              CEDA_variant=CEDA_VARIANT, clean_criterion='kl', **ceda_kwargs)
elif hps.train_type == 'CEDAKL':
    trainer = tt.CEDATraining(model, optimizer_config, total_epochs, device, num_classes,
                              clean_criterion='kl', **ceda_kwargs)
elif hps.train_type == 'plainKL':
    trainer = tt.PlainTraining(model, optimizer_config, total_epochs, device, num_classes,
                               clean_criterion='kl', **trainer_kwargs)
else:
    raise ValueError('Train type {} is not supported'.format(hps.train_type))

# Debugging aid, disabled by default: torch.autograd.set_detect_anomaly(True)

torch.backends.cudnn.benchmark = True

# Build the loader dicts; the OD loader is only handed over when the trainer
# reports that it requires out-of-distribution data.
loader_kwargs = {'test_loader': test_loader, 'extra_test_loaders': [extra_test_loader]}
if trainer.requires_out_distribution():
    loader_kwargs['out_distribution_loader'] = od_loader
train_loaders, test_loaders = trainer.create_loaders_dict(train_loader, **loader_kwargs)

trainer.train(train_loaders, test_loaders, loader_config=loader_config, start_epoch=start_epoch,
              optim_state_dict=optim_state_dict)



